/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <tests/binaryop/util/runtime_support.h>

#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/default_stream.hpp>

#include <cudf/binaryop.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/scalar/scalar.hpp>

// Fixture shared by the binary-operation tests below; it adds nothing on top of
// cudf::test::BaseFixture.
struct BinaryopTest : public cudf::test::BaseFixture {};

// Column + column: runs an ADD over two five-element int32 columns, handing the
// stream from cudf::test::get_default_stream() to the API explicitly.
// The returned result is discarded; no assertion is made on its contents.
TEST_F(BinaryopTest, ColumnColumn)
{
  auto const stream      = cudf::test::get_default_stream();
  auto const output_type = cudf::data_type(cudf::type_to_id<int32_t>());

  cudf::test::fixed_width_column_wrapper<int32_t> left{10, 20, 30, 40, 50};
  cudf::test::fixed_width_column_wrapper<int32_t> right{15, 25, 35, 45, 55};

  cudf::binary_operation(left, right, cudf::binary_operator::ADD, output_type, stream);
}

// Column + scalar: runs an ADD between a five-element int32 column and a valid
// int32 scalar. The same stream is used to construct the scalar and to run the
// operation. The returned result is discarded.
TEST_F(BinaryopTest, ColumnScalar)
{
  auto const stream      = cudf::test::get_default_stream();
  auto const output_type = cudf::data_type(cudf::type_to_id<int32_t>());

  cudf::test::fixed_width_column_wrapper<int32_t> column_operand{10, 20, 30, 40, 50};
  cudf::numeric_scalar<int32_t> scalar_operand{23, true, stream};

  cudf::binary_operation(
    column_operand, scalar_operand, cudf::binary_operator::ADD, output_type, stream);
}

// Scalar + column: runs an ADD between a valid int32 scalar (left operand) and a
// five-element int32 column (right operand). The same stream is used to
// construct the scalar and to run the operation. The returned result is
// discarded.
TEST_F(BinaryopTest, ScalarColumn)
{
  auto const stream      = cudf::test::get_default_stream();
  auto const output_type = cudf::data_type(cudf::type_to_id<int32_t>());

  cudf::numeric_scalar<int32_t> scalar_operand{42, true, stream};
  cudf::test::fixed_width_column_wrapper<int32_t> column_operand{15, 25, 35, 45, 55};

  cudf::binary_operation(
    scalar_operand, column_operand, cudf::binary_operator::ADD, output_type, stream);
}

// Fixture for the PTX-based binary-operation tests. Each test using it is
// skipped when can_do_runtime_jit() reports that runtime JIT is unavailable.
class BinaryopPTXTest : public BinaryopTest {
 protected:
  void SetUp() override
  {
    // Nothing to do when runtime JIT is usable; otherwise mark the test skipped.
    if (can_do_runtime_jit()) { return; }
    GTEST_SKIP() << "Skipping tests that require 11.5 runtime";
  }
};

// Column/column binary operation driven by a user-supplied PTX function.
// The operation is run twice, once requesting an int32 output type and once an
// int64 output type. Both calls pass the stream from
// cudf::test::get_default_stream() explicitly, matching every other test in
// this file. The returned results are discarded.
TEST_F(BinaryopPTXTest, ColumnColumnPTX)
{
  cudf::test::fixed_width_column_wrapper<int32_t> lhs{10, 20, 30, 40, 50};
  cudf::test::fixed_width_column_wrapper<int64_t> rhs{15, 25, 35, 45, 55};

  // c = a*a*a + b*b
  char const* ptx =
    R"***(
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//

.version 6.3
.target sm_70
.address_size 64

	// .globl	_ZN8__main__7add$241Eix
.common .global .align 8 .u64 _ZN08NumbaEnv8__main__7add$241Eix;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets7numbers14int_power_impl12$3clocals$3e13int_power$242Exx;

.visible .func  (.param .b32 func_retval0) _ZN8__main__7add$241Eix(
	.param .b64 _ZN8__main__7add$241Eix_param_0,
	.param .b32 _ZN8__main__7add$241Eix_param_1,
	.param .b64 _ZN8__main__7add$241Eix_param_2
)
{
	.reg .b32 	%r<3>;
	.reg .b64 	%rd<8>;


	ld.param.u64 	%rd1, [_ZN8__main__7add$241Eix_param_0];
	ld.param.u32 	%r1, [_ZN8__main__7add$241Eix_param_1];
	ld.param.u64 	%rd2, [_ZN8__main__7add$241Eix_param_2];
	cvt.s64.s32	%rd3, %r1;
	mul.wide.s32 	%rd4, %r1, %r1;
	mul.lo.s64 	%rd5, %rd4, %rd3;
	mul.lo.s64 	%rd6, %rd2, %rd2;
	add.s64 	%rd7, %rd6, %rd5;
	st.u64 	[%rd1], %rd7;
	mov.u32 	%r2, 0;
	st.param.b32	[func_retval0+0], %r2;
	ret;
}

)***";

  cudf::binary_operation(
    lhs, rhs, ptx, cudf::data_type(cudf::type_to_id<int32_t>()), cudf::test::get_default_stream());
  // Pass the stream here as well; omitting it would fall back to the API's
  // default stream argument instead of the stream under test.
  cudf::binary_operation(
    lhs, rhs, ptx, cudf::data_type(cudf::type_to_id<int64_t>()), cudf::test::get_default_stream());
}
